import torch
from torch import nn
import torch.nn.functional as F


class TransformerEncoderLayer(nn.Module):
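    """Thin wrapper around nn.TransformerEncoder.

    Note: nn.TransformerEncoderLayer requires d_model (``input_size`` here)
    to be divisible by ``nhead``, and by default expects input of shape
    (seq_len, batch, d_model).
    """
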
    def __init__(self, input_size, hidden_size, num_layers=1, num_heads=8, dropout=0.1):
        super().__init__()
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=input_size,
                nhead=num_heads,
                dim_feedforward=hidden_size,
                dropout=dropout
            ),
            num_layers=num_layers
        )

    def forward(self, x):
        # Input and output both have shape (seq_len, batch, d_model)
        return self.transformer_encoder(x)
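
# Example usage of TransformerEncoderLayer (a minimal sketch; the shapes are
# assumptions chosen to match dbFcn below, not taken from a training script):
#
#     encoder = TransformerEncoderLayer(input_size=128, hidden_size=512)
#     x = torch.randn(1, 12, 128)  # (seq_len, batch, d_model)
#     y = encoder(x)               # same shape: (1, 12, 128)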


class dbFcn(nn.Module):
    # Bookkeeping attributes (current epoch/batch), presumably updated from
    # the training loop via set_val
    g_epoch = 0
    g_batch = 0

    def set_val(self, n_epoch, n_batch):
        self.g_epoch = n_epoch
        self.g_batch = n_batch

    def __init__(self):
        super().__init__()
        # 9. Feature flattening
        self.flatten9 = nn.Flatten()
        # Transformer layer
        self.transformer_encoder = TransformerEncoderLayer(input_size=128, hidden_size=512, num_layers=6)
        # Fully connected layers
        self.L10 = nn.Linear(12, 128)
        self.L11 = nn.Linear(128, 64)
        self.L12 = nn.Linear(64, 4)

    def forward(self, x):
        # Flatten the input; the hard-coded sizes below assume the result is (128, 12)
        x = self.flatten9(x)
        x = x.unsqueeze(0)      # add a batch dimension: (1, 128, 12)
        x = x.permute(0, 2, 1)  # -> (1, 12, 128), i.e. (seq_len, batch, d_model)
        x = self.transformer_encoder(x)
        x = x.squeeze(0)        # remove the batch dimension: (12, 128)
        x = x.permute(1, 0)     # -> (128, 12)
        # Three fully connected layers
        x = self.L10(x)
        x = self.L11(x)
        x = self.L12(x)
        x = F.softmax(x, dim=1)
        return x
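

# Minimal smoke test (a sketch under assumed shapes: Linear(12, 128) and the
# transformer's d_model of 128 imply an input that flattens to (128, 12); the
# real data pipeline is not shown in this file).
if __name__ == "__main__":
    model = dbFcn()
    dummy = torch.randn(128, 12)  # assumed: 128 positions x 12 features
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([128, 4]); rows sum to 1 (softmax)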
